package com.xiam.consia.ml.classifiers;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.xiam.consia.algs.predict.property.PropertyManager;
import com.xiam.consia.ml.attributeselection.AttributeSelection;
import com.xiam.consia.ml.classifiers.ClassifierConstants;
import com.xiam.consia.ml.data.DataRecord;
import com.xiam.consia.ml.data.DataRecords;
import com.xiam.consia.ml.data.ProbResults;
import com.xiam.consia.ml.data.attribute.Attribute;
import com.xiam.snapdragon.app.data.property.PropertyConstants;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: classes.dex */
/**
 * Naive Bayes classifier over a mix of discrete and continuous attributes.
 *
 * <p>Discrete attributes use Laplace-smoothed value counts per class label. Continuous attributes
 * are modelled per class by a normal distribution (sample mean and population variance). The
 * likelihood of an observed value {@code x} is the normal probability mass on
 * {@code (x - h, x + h)}, where {@code h} is half the mean non-zero gap between the sorted, pooled
 * training values of that attribute.
 *
 * <p>Class priors are Laplace-smoothed class frequencies. Posteriors are accumulated in log space
 * so that many small per-attribute likelihoods cannot underflow to zero.
 *
 * <p>When only the labels {@code "yes"} and {@code "no"} are known, the normalized posteriors are
 * re-weighted by {@code class1Weight} ("yes") and {@code class2Weight} ("no") before the winning
 * label is chosen.
 */
public class NaiveBayes extends Classifier {
    /** Floor for a single continuous likelihood so one outlying value cannot zero a posterior. */
    private static final double MIN_PROBABILITY = 1.0E-75d;

    // Coefficients of the Abramowitz & Stegun 26.2.17 polynomial approximation of the normal CDF.
    private static final double AS_P = 0.2316419d;
    private static final double AS_B1 = 0.31938153d;
    private static final double AS_B2 = -0.356563782d;
    private static final double AS_B3 = 1.781477937d;
    private static final double AS_B4 = -1.821255978d;
    private static final double AS_B5 = 1.330274429d;
    /** 1 / sqrt(2 * pi), the standard normal density at 0. */
    private static final double INV_SQRT_2PI = 1.0d / Math.sqrt(6.283185307179586d);

    /** Multiplier for the "yes" probability in the two-class case. */
    private double class1Weight;
    /** Multiplier for the "no" probability in the two-class case. */
    private double class2Weight;
    /** Training records per class label; "yes" and "no" are always present after training. */
    private Map<String, Integer> classCounts;
    /** Discrete-attribute counts: attribute name -> attribute value -> class label -> count. */
    private Map<String, Map<String, Map<String, Integer>>> dataModel;
    /** Number of records used by the most recent call to buildClassifier. */
    private int numberTrainingRecords;

    /**
     * Creates an untrained classifier and loads its class weights.
     *
     * @param propertyManager source of the class-weight properties
     * @param predictionType selects the weight properties to read; its class count is passed to
     *     the superclass
     * @param str forwarded unchanged to the {@link Classifier} constructor
     */
    public NaiveBayes(PropertyManager propertyManager, ClassifierConstants.PredictionType predictionType, String str) {
        super(str, predictionType.getClassCount());
        this.numberTrainingRecords = 0;
        this.class1Weight = 1.0d;
        this.class2Weight = 1.0d;
        loadProperties(predictionType, propertyManager);
    }

    /**
     * Returns half of the mean strictly-positive gap between consecutive sorted values, or 0 when
     * there is no positive gap (fewer than two distinct values). Sorts {@code list} in place.
     */
    private static double estimateInterval(List<Double> list) {
        Collections.sort(list);
        double gapSum = 0.0d;
        int gapCount = 0;
        for (int i = 1; i < list.size(); i++) {
            double gap = list.get(i).doubleValue() - list.get(i - 1).doubleValue();
            if (gap > 0.0d) {
                gapSum += gap;
                gapCount++;
            }
        }
        // Previously 0/0 = NaN here; 0 gives the same (floored) likelihood downstream without NaN.
        return gapCount == 0 ? 0.0d : gapSum / (gapCount * 2.0d);
    }

    /** Returns true for values treated as missing: Java null, "null" or "NA" (any case). */
    private static boolean isMissing(String value) {
        return value == null || value.equalsIgnoreCase("null") || value.equalsIgnoreCase("NA");
    }

    /**
     * Returns the likelihood of continuous value {@code d} under class {@code str}.
     *
     * <p>The result is the mass of N(mean, variance) on {@code (d - h, d + h)}, with {@code h} the
     * attribute's sample interval. It falls back to {@link #MIN_PROBABILITY} when the class has no
     * statistics for this attribute or the mass is not positive.
     */
    private double getGaussianProb(Attribute attribute, double d, String str) {
        Double mean = attribute.getContinuousSampleMeansPerClass().get(str);
        Double variance = attribute.getContinuousSampleVariancesPerClass().get(str);
        double prob = 0.0d;
        if (mean != null && variance != null) {
            double halfWidth = attribute.getSampleInterval();
            double sigma = Math.sqrt(variance.doubleValue());
            if (sigma > 0.0d) {
                prob = standardNormalApproximation(((halfWidth + d) - mean.doubleValue()) / sigma)
                        - standardNormalApproximation(((d - halfWidth) - mean.doubleValue()) / sigma);
            } else {
                // Zero variance: every training sample equals the mean, so the model is a point
                // mass. Handled explicitly instead of dividing by zero.
                prob = Math.abs(d - mean.doubleValue()) < halfWidth ? 1.0d : 0.0d;
            }
        }
        return prob > 0.0d ? prob : MIN_PROBABILITY;
    }

    /**
     * For every continuous attribute in {@code list}, stores the per-class sample mean and
     * population variance, plus a sample interval estimated from all classes' values pooled.
     */
    private void getSampleStatisticsForCtsAttributes(List<Attribute> list) {
        for (Attribute attribute : list) {
            if (attribute.isDiscrete) {
                continue;
            }
            List<Double> pooled = Lists.newArrayList();
            for (String label : this.classCounts.keySet()) {
                List<Double> samples = attribute.getContinuousSampleValuesPerClass().get(label);
                if (samples == null || samples.isEmpty()) {
                    continue;
                }
                pooled.addAll(samples);
                double sum = 0.0d;
                for (Double sample : samples) {
                    sum += sample.doubleValue();
                }
                double mean = sum / samples.size();
                double squaredDeviations = 0.0d;
                for (Double sample : samples) {
                    double deviation = sample.doubleValue() - mean;
                    squaredDeviations += deviation * deviation;
                }
                attribute.getContinuousSampleMeansPerClass().put(label, Double.valueOf(mean));
                // Population variance (divides by n, not n - 1).
                attribute.getContinuousSampleVariancesPerClass().put(label, Double.valueOf(squaredDeviations / samples.size()));
            }
            if (!pooled.isEmpty()) {
                attribute.setSampleInterval(estimateInterval(pooled));
            }
        }
    }

    /**
     * Trains the model from {@code dataRecords}, replacing previously learned class and
     * discrete-value counts.
     *
     * <p>Discrete attribute values are counted per class label. Continuous values are parsed as
     * doubles and appended to the attribute's per-class sample lists. Continuous values that are
     * Java null, "null" or "NA" (any case) are skipped as missing. Per-class means, variances and
     * the sample interval are then computed for every continuous attribute.
     *
     * <p>NOTE(review): continuous samples are stored on the {@link Attribute} objects themselves,
     * so retraining with the same Attribute instances appends to the earlier samples. Confirm that
     * callers supply fresh attributes.
     *
     * @param dataRecords training data
     * @param attributeSelection not used by this classifier
     * @throws NumberFormatException if a non-missing continuous value is not a valid double
     */
    @Override // com.xiam.consia.ml.classifiers.Classifier
    public void buildClassifier(DataRecords dataRecords, AttributeSelection attributeSelection) {
        this.classCounts = Maps.newHashMapWithExpectedSize(25);
        this.dataModel = Maps.newHashMapWithExpectedSize(1000);
        // Both binary labels always exist so the two-class branch of classify can rely on them.
        this.classCounts.put("yes", 0);
        this.classCounts.put("no", 0);

        int recordCount = dataRecords.getNumRecords();
        int attributeCount = dataRecords.getNumAttributes();
        List<Attribute> attributes = dataRecords.getAttributes();
        for (int r = 0; r < recordCount; r++) {
            DataRecord record = dataRecords.getDataRecords().get(r);
            String label = record.getClassLabel();
            incrementClassCount(label);
            List<String> values = record.getAttributeValues();
            for (int a = 0; a < attributeCount; a++) {
                Attribute attribute = attributes.get(a);
                String value = values.get(a);
                if (attribute.isDiscrete) {
                    incrementDiscreteCount(attribute.getName(), value, label);
                } else if (!isMissing(value)) {
                    addContinuousSample(attribute, label, Double.parseDouble(value));
                }
            }
        }
        // Counted once; the previous code assigned the total and then also incremented per record.
        this.numberTrainingRecords = recordCount;
        getSampleStatisticsForCtsAttributes(attributes);
    }

    /** Adds one training record to the count for {@code label}. */
    private void incrementClassCount(String label) {
        Integer count = this.classCounts.get(label);
        this.classCounts.put(label, Integer.valueOf(count == null ? 1 : count.intValue() + 1));
    }

    /** Increments the (attribute, value, label) count in {@link #dataModel}, creating maps as needed. */
    private void incrementDiscreteCount(String attributeName, String value, String label) {
        Map<String, Map<String, Integer>> valueCounts = this.dataModel.get(attributeName);
        if (valueCounts == null) {
            valueCounts = new HashMap<String, Map<String, Integer>>();
            this.dataModel.put(attributeName, valueCounts);
        }
        Map<String, Integer> labelCounts = valueCounts.get(value);
        if (labelCounts == null) {
            labelCounts = new HashMap<String, Integer>();
            valueCounts.put(value, labelCounts);
        }
        Integer count = labelCounts.get(label);
        labelCounts.put(label, Integer.valueOf(count == null ? 1 : count.intValue() + 1));
    }

    /** Returns the (attribute, value, label) count, or 0 if that combination was never trained. */
    private int getDiscreteCount(String attributeName, String value, String label) {
        Map<String, Map<String, Integer>> valueCounts = this.dataModel.get(attributeName);
        if (valueCounts == null) {
            return 0;
        }
        Map<String, Integer> labelCounts = valueCounts.get(value);
        if (labelCounts == null) {
            return 0;
        }
        Integer count = labelCounts.get(label);
        return count == null ? 0 : count.intValue();
    }

    /** Appends {@code value} to the attribute's sample list for {@code label}, creating the list if absent. */
    private static void addContinuousSample(Attribute attribute, String label, double value) {
        List<Double> samples = attribute.getContinuousSampleValuesPerClass().get(label);
        if (samples == null) {
            ArrayList<Double> created = new ArrayList<Double>();
            attribute.getContinuousSampleValuesPerClass().put(label, created);
            samples = created;
        }
        samples.add(Double.valueOf(value));
    }

    /**
     * Returns log P(label) + sum of log P(value | label) over the record's non-missing attributes.
     * Both the prior and the discrete likelihoods are Laplace-smoothed.
     */
    private double logPosterior(DataRecord dataRecord, String label) {
        int labelCount = this.classCounts.get(label).intValue();
        double logProb = Math.log((labelCount + 1.0d) / (this.numberTrainingRecords + this.classCounts.size()));
        List<String> values = dataRecord.getAttributeValues();
        int attributeCount = dataRecord.getAttributes().size();
        for (int i = 0; i < attributeCount; i++) {
            Attribute attribute = dataRecord.getAttributes().get(i);
            String value = values.get(i);
            if (isMissing(value)) {
                continue;
            }
            if (attribute.isDiscrete) {
                int jointCount = getDiscreteCount(attribute.getName(), value, label);
                logProb += Math.log((jointCount + 1.0d) / (attribute.getAllowedDiscreteValues().size() + labelCount));
            } else {
                logProb += Math.log(getGaussianProb(attribute, Double.parseDouble(value), label));
            }
        }
        return logProb;
    }

    /**
     * Classifies one record.
     *
     * <p>With more than two known class labels, returns the maximum-a-posteriori label and its
     * normalized posterior probability. With only "yes" and "no", takes the normalized posterior of
     * "no" and its complement, scales them by class2Weight and class1Weight respectively, and
     * renormalizes them. It then returns the larger one with its label (ties go to "yes").
     *
     * <p>NOTE(review): continuous likelihoods read statistics stored on the record's
     * {@link Attribute} objects. This assumes they are the instances populated by buildClassifier;
     * confirm.
     *
     * @throws IllegalStateException if called before {@link #buildClassifier}
     * @throws NumberFormatException if a non-missing continuous value is not a valid double
     */
    @Override // com.xiam.consia.ml.classifiers.Classifier
    public ProbResults classify(DataRecord dataRecord) {
        if (this.classCounts == null) {
            throw new IllegalStateException("NaiveBayes.classify called before buildClassifier");
        }
        Map<String, Double> logPosteriors = new HashMap<String, Double>();
        String bestLabel = PropertyConstants.SDA_CIID_DEFAULT;
        double bestLogPosterior = Double.NEGATIVE_INFINITY;
        for (String label : this.classCounts.keySet()) {
            double logPosterior = logPosterior(dataRecord, label);
            logPosteriors.put(label, Double.valueOf(logPosterior));
            // Strict '>' keeps the first label in iteration order on ties.
            if (logPosterior > bestLogPosterior) {
                bestLogPosterior = logPosterior;
                bestLabel = label;
            }
        }

        // Evidence scaled by exp(-best) so the largest term is exactly 1 and nothing underflows.
        double scaledEvidence = 0.0d;
        for (Double logPosterior : logPosteriors.values()) {
            scaledEvidence += Math.exp(logPosterior.doubleValue() - bestLogPosterior);
        }
        if (this.classCounts.size() > 2) {
            return new ProbResults(bestLabel, 1.0d / scaledEvidence);
        }

        // Two-class case: the normalized posterior of "no" (previously a raw class count divided by
        // an unrelated sum, which ignored the record and could exceed 1).
        Double logNo = logPosteriors.get("no");
        double probNo = logNo == null ? 0.0d : Math.exp(logNo.doubleValue() - bestLogPosterior) / scaledEvidence;
        double weightedNo = probNo * this.class2Weight;
        double weightedYes = (1.0d - probNo) * this.class1Weight;
        double weightedTotal = weightedYes + weightedNo;
        double normalizedNo = weightedNo / weightedTotal;
        double normalizedYes = weightedYes / weightedTotal;
        return normalizedNo > normalizedYes ? new ProbResults("no", normalizedNo) : new ProbResults("yes", normalizedYes);
    }

    /**
     * Loads the "yes" (class 1) and "no" (class 2) weights for {@code predictionType} from
     * {@code propertyManager}. Prediction types other than APP, PHONEON and PLACEMOVE leave the
     * current weights unchanged. Exceptions from {@link Double#parseDouble} propagate if a weight
     * property cannot be parsed.
     *
     * <p>NOTE(review): the keys are named PREDICT_RF_* and may be shared with another classifier.
     * Confirm this is intended.
     */
    @Override // com.xiam.consia.ml.classifiers.Classifier
    public void loadProperties(ClassifierConstants.PredictionType predictionType, PropertyManager propertyManager) {
        if (predictionType == ClassifierConstants.PredictionType.APP) {
            this.class1Weight = Double.parseDouble(propertyManager.getStringProperty(com.xiam.consia.data.constants.PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_APP));
            this.class2Weight = Double.parseDouble(propertyManager.getStringProperty(com.xiam.consia.data.constants.PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_APP));
        } else if (predictionType == ClassifierConstants.PredictionType.PHONEON) {
            this.class1Weight = Double.parseDouble(propertyManager.getStringProperty(com.xiam.consia.data.constants.PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_PHONEON));
            this.class2Weight = Double.parseDouble(propertyManager.getStringProperty(com.xiam.consia.data.constants.PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_PHONEON));
        } else if (predictionType == ClassifierConstants.PredictionType.PLACEMOVE) {
            this.class1Weight = Double.parseDouble(propertyManager.getStringProperty(com.xiam.consia.data.constants.PropertyConstants.PREDICT_RF_CLASS1_WEIGHT_PLACEMOVE));
            this.class2Weight = Double.parseDouble(propertyManager.getStringProperty(com.xiam.consia.data.constants.PropertyConstants.PREDICT_RF_CLASS2_WEIGHT_PLACEMOVE));
        }
    }

    /**
     * Approximates the standard normal CDF Phi(d) with Abramowitz & Stegun formula 26.2.17
     * (absolute error below 7.5e-8). Negative arguments use the symmetry Phi(-x) = 1 - Phi(x).
     */
    public double standardNormalApproximation(double d) {
        if (d < 0.0d) {
            return 1.0d - standardNormalApproximation(-d);
        }
        double t = 1.0d / (1.0d + AS_P * d);
        // Horner form of b1 + b2*t + b3*t^2 + b4*t^3 + b5*t^4.
        double poly = AS_B1 + t * (AS_B2 + t * (AS_B3 + t * (AS_B4 + t * AS_B5)));
        double density = INV_SQRT_2PI * Math.exp(-(d * d) / 2.0d);
        return 1.0d - poly * (density * t);
    }
}
